import torch


def accuracy(dt, gt):  # top-1 acc
    """Return the top-1 accuracy of predicted class indices against targets.

    Args:
        dt: Predicted class indices, one per sample. Shape ``(N,)``. An
            ``(N, 1)`` tensor (e.g. the indices returned by
            ``logits.topk(1, dim=1)``) is also accepted and flattened.
        gt: Targets. Either a 2-D ``(N, C)`` tensor of per-class scores or
            one-hot rows, where the target class is the argmax along dim 1,
            or a 1-D ``(N,)`` tensor of class indices.

    Returns:
        0-d floating-point tensor holding the fraction of samples whose
        prediction equals the target. An empty batch yields ``nan`` (0 / 0).

    Raises:
        ValueError: If ``dt`` and the target labels contain a different
            number of samples.
    """
    # 1-D targets are already class indices; otherwise reduce scores/one-hot
    # rows to indices exactly as before.
    labels = gt if gt.dim() == 1 else torch.argmax(gt, dim=1)
    # Flatten both sides so an (N, 1) tensor compares elementwise with an (N,)
    # tensor instead of silently broadcasting to an (N, N) matrix.
    preds = dt.reshape(-1)
    labels = labels.reshape(-1)
    if preds.numel() != labels.numel():
        raise ValueError(
            f"accuracy: got {preds.numel()} predictions but "
            f"{labels.numel()} targets"
        )
    return (preds == labels).sum() / preds.size(0)
